# Convert a Jittor checkpoint (.pkl) into a PyTorch checkpoint (.pth).

import os

import jittor as jt
import torch

load_path = "/data/share/leixy/ccnet_jittor/ckpt/cityscapes-resnet101-0.01-cca_deepsup/epoch_35.pkl"
save_path = "./ckpt/cityscapes-resnet101-0.01-cca_deepsup-pytorch/epoch_35.pth"


def convert_checkpoint(src, dst):
    """Load a Jittor checkpoint from *src* and save it as a PyTorch checkpoint at *dst*.

    Every value of the mapping returned by ``jt.load`` is replaced by
    ``torch.tensor(value)``; keys are kept unchanged. Missing parent
    directories of *dst* are created before saving.

    Returns the converted mapping (the same object ``jt.load`` returned,
    with its values replaced in place).
    """
    params = jt.load(src)
    # Replace values in place (no keys are added or removed, so mutating while
    # iterating is safe) to keep whatever mapping type jt.load produced.
    for name, value in params.items():
        params[name] = torch.tensor(value)
    # NOTE(review): an earlier version optionally wrapped the result as
    # {"state_dict": params}; do that here if the consuming PyTorch loader
    # expects that layout.
    out_dir = os.path.dirname(dst)
    if out_dir:
        # torch.save does not create directories; fail-free if it already exists.
        os.makedirs(out_dir, exist_ok=True)
    torch.save(params, dst)
    return params


parameters = convert_checkpoint(load_path, save_path)